{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fcc28e11-c132-4816-89f8-60143510df7d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T01:50:44.152064Z",
     "iopub.status.busy": "2022-07-12T01:50:44.151480Z",
     "iopub.status.idle": "2022-07-12T01:50:55.149411Z",
     "shell.execute_reply": "2022-07-12T01:50:55.144781Z",
     "shell.execute_reply.started": "2022-07-12T01:50:44.152015Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "# Only the inline plotting backend is needed. `%pylab inline` would additionally star-import\n",
    "# numpy and matplotlib.pylab into the notebook namespace, shadowing builtins such as sum/any/all.\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "\n",
    "# Number devices by PCI bus id and expose no GPU to this process, i.e. force a CPU-only run.\n",
    "# These variables only take effect if they are set before CUDA is initialised, so keep this\n",
    "# cell ahead of the cell that imports torch.\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "\n",
    "import sys\n",
    "import codecs\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def read_data(path='data_v2/'):\n",
    "    \"\"\"Load the three CSV tables stored under ``path``.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(vid_info, seq_train, candidate_items)`` where\n",
    "\n",
    "        * ``vid_info`` is ``vid_info.csv`` ordered by ``cid`` then ``serialno``, with the\n",
    "          ``stars``, ``tags`` and ``key_word`` columns evaluated from their string form;\n",
    "        * ``seq_train`` is ``main_vv_seq_train.csv`` ordered by ``did`` then ``seq_no``;\n",
    "          the pre-sort row labels are kept in an ``index`` column;\n",
    "        * ``candidate_items`` is ``candidate_items_A.csv`` exactly as read.\n",
    "    \"\"\"\n",
    "    def csv_path(file_name):\n",
    "        return os.path.join(path, file_name)\n",
    "\n",
    "    vid_info = pd.read_csv(csv_path('vid_info.csv'))\n",
    "    # NOTE(review): eval() executes whatever the CSV cells contain, so only run this on\n",
    "    # trusted files. ast.literal_eval is the safe choice if the cells are plain literals.\n",
    "    for column in ('stars', 'tags', 'key_word'):\n",
    "        vid_info[column] = vid_info[column].apply(eval)\n",
    "    vid_info = vid_info.sort_values(by=['cid', 'serialno'])\n",
    "\n",
    "    seq_train = (\n",
    "        pd.read_csv(csv_path('main_vv_seq_train.csv'))\n",
    "        .sort_values(by=['did', 'seq_no'])\n",
    "        .reset_index()\n",
    "    )\n",
    "\n",
    "    candidate_items = pd.read_csv(csv_path('candidate_items_A.csv'))\n",
    "    return vid_info, seq_train, candidate_items\n",
    "\n",
    "\n",
    "# Load all three tables once, from the default 'data_v2/' directory.\n",
    "vid_info, seq_train, candidate_items = read_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "bbab2440-be3a-4c67-9c5f-cf3e7d05b5bd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T01:50:55.150671Z",
     "iopub.status.busy": "2022-07-12T01:50:55.150493Z",
     "iopub.status.idle": "2022-07-12T01:50:55.153717Z",
     "shell.execute_reply": "2022-07-12T01:50:55.153314Z",
     "shell.execute_reply.started": "2022-07-12T01:50:55.150655Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "from transformers import AutoConfig, AutoModelForMaskedLM\n",
    "from transformers import Trainer, TrainingArguments\n",
    "from transformers import DataCollatorForLanguageModeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "16691cb0-dc12-4b7d-9265-83071c477b58",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T01:50:55.154742Z",
     "iopub.status.busy": "2022-07-12T01:50:55.154580Z",
     "iopub.status.idle": "2022-07-12T01:51:04.471436Z",
     "shell.execute_reply": "2022-07-12T01:51:04.470760Z",
     "shell.execute_reply.started": "2022-07-12T01:50:55.154727Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /home/lyz/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab11005bcd270f3c34464dc1704b715b5d7d52b1a461abe3b9e4e\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"bert-base-uncased\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.18.0\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "loading file https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt from cache at /home/lyz/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99\n",
      "loading file https://huggingface.co/bert-base-uncased/resolve/main/tokenizer.json from cache at /home/lyz/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4\n",
      "loading file https://huggingface.co/bert-base-uncased/resolve/main/added_tokens.json from cache at None\n",
      "loading file https://huggingface.co/bert-base-uncased/resolve/main/special_tokens_map.json from cache at None\n",
      "loading file https://huggingface.co/bert-base-uncased/resolve/main/tokenizer_config.json from cache at /home/lyz/.cache/huggingface/transformers/c1d7f0a763fb63861cc08553866f1fc3e5a6f4f07621be277452d26d71303b7e.20430bd8e10ef77a7d2977accefe796051e01bc2fc4aa146bc862997a1a15e79\n",
      "loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /home/lyz/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab11005bcd270f3c34464dc1704b715b5d7d52b1a461abe3b9e4e\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"bert-base-uncased\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.18.0\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Start from the stock bert-base-uncased tokenizer; item ids are added to it in the next cell.\n",
    "tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2993a9a2-9dd7-4cee-932d-323494f98a16",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T01:51:23.380213Z",
     "iopub.status.busy": "2022-07-12T01:51:23.379678Z",
     "iopub.status.idle": "2022-07-12T01:51:23.594042Z",
     "shell.execute_reply": "2022-07-12T01:51:23.593576Z",
     "shell.execute_reply.started": "2022-07-12T01:51:23.380166Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13406"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Register every candidate video id as a token of its own. The cell output is the value\n",
    "# returned by add_tokens.\n",
    "# This must run before the model is built: the model's vocab_size is taken from len(tokenizer).\n",
    "# NOTE(review): ids that occur in seq_train but not in candidate_items are not registered here\n",
    "# and are presumably split into sub-word pieces by the tokenizer -- confirm this is intended.\n",
    "tokenizer.add_tokens(candidate_items['vid'].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "674a8319-4238-4936-9844-df4384eb9868",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-11T23:16:47.917708Z",
     "iopub.status.busy": "2022-07-11T23:16:47.917492Z",
     "iopub.status.idle": "2022-07-11T23:16:47.923083Z",
     "shell.execute_reply": "2022-07-11T23:16:47.922557Z",
     "shell.execute_reply.started": "2022-07-11T23:16:47.917690Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Minimal custom dataset for masked-LM pre-training on sequences of item ids.\n",
    "class PreTrainDataset(Dataset):\n",
    "    \"\"\"Tokenise each sequence of item ids into fixed-length model inputs.\n",
    "\n",
    "    Args:\n",
    "        data_list: Iterable (list, pandas Series, ...) whose elements are lists of string\n",
    "            tokens; each element becomes one example.\n",
    "        tokenizer: Object exposing a Hugging Face style ``encode_plus`` method.\n",
    "        max_seq_len: Length every example is padded / truncated to.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_list, tokenizer, max_seq_len):\n",
    "        super(PreTrainDataset, self).__init__()\n",
    "        # Materialise as a plain list so that __getitem__ is always positional. A pandas\n",
    "        # Series indexed by user id would otherwise treat an integer index as a label\n",
    "        # (or rely on a deprecated positional fallback).\n",
    "        self.data_list = list(data_list)\n",
    "        self.len = len(self.data_list)\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_seq_len = max_seq_len\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``input_ids``, ``token_type_ids`` and ``attention_mask`` as long tensors.\"\"\"\n",
    "        example = self.data_list[index]\n",
    "        # padding='max_length' together with truncation=True already yields exactly\n",
    "        # max_seq_len entries per field, so no further slicing is needed.\n",
    "        # NOTE(review): with the tokenizer's default right-side truncation, long sequences keep\n",
    "        # their earliest items, whereas inference looks at the most recent ones -- confirm.\n",
    "        encoded = self.tokenizer.encode_plus(\n",
    "            ' '.join(example),\n",
    "            return_token_type_ids=True,\n",
    "            return_attention_mask=True,\n",
    "            padding='max_length',\n",
    "            truncation=True,\n",
    "            max_length=self.max_seq_len,\n",
    "        )\n",
    "        fields = ('input_ids', 'token_type_ids', 'attention_mask')\n",
    "        return {name: torch.tensor(encoded[name], dtype=torch.long) for name in fields}\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6f86fde0-1ab9-459a-b5e7-e06bb329d125",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T01:52:48.153163Z",
     "iopub.status.busy": "2022-07-12T01:52:48.152575Z",
     "iopub.status.idle": "2022-07-12T01:52:53.084593Z",
     "shell.execute_reply": "2022-07-12T01:52:53.083759Z",
     "shell.execute_reply.started": "2022-07-12T01:52:48.153115Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One list of video ids per `did`, in the row order of seq_train.\n",
    "vids_by_did = seq_train.groupby(['did'])['vid']\n",
    "seq_vid_list = vids_by_did.apply(list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "adffe4c8-ba49-4f90-8a06-2a2c89cd0d8a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T01:52:53.085819Z",
     "iopub.status.busy": "2022-07-12T01:52:53.085640Z",
     "iopub.status.idle": "2022-07-12T01:52:53.100350Z",
     "shell.execute_reply": "2022-07-12T01:52:53.099388Z",
     "shell.execute_reply.started": "2022-07-12T01:52:53.085803Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Hold out the last N_VAL sequences for validation; everything before them is training data.\n",
    "N_VAL = 2000\n",
    "MAX_SEQ_LEN = 16\n",
    "\n",
    "train_sequences = seq_vid_list[:-N_VAL]\n",
    "val_sequences = seq_vid_list[-N_VAL:]\n",
    "\n",
    "train_dataset = PreTrainDataset(train_sequences, tokenizer, MAX_SEQ_LEN)\n",
    "val_dataset = PreTrainDataset(val_sequences, tokenizer, MAX_SEQ_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a1de888d-c5d3-4d0d-ba0e-9ae30cf280ab",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-11T23:16:52.790327Z",
     "iopub.status.busy": "2022-07-11T23:16:52.790136Z",
     "iopub.status.idle": "2022-07-11T23:16:58.545925Z",
     "shell.execute_reply": "2022-07-11T23:16:58.545415Z",
     "shell.execute_reply.started": "2022-07-11T23:16:52.790311Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Seed torch before the model is created so the random initialisation below is reproducible\n",
    "# (42 is also the default seed of TrainingArguments).\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# Build the model: BERT-base architecture with randomly initialised weights (from_config does\n",
    "# not load the pretrained checkpoint). Taking vocab_size from the tokenizer sizes the embedding\n",
    "# matrix for every token registered on it, so no later resize is required.\n",
    "config = AutoConfig.from_pretrained('bert-base-uncased')\n",
    "config.vocab_size = len(tokenizer)\n",
    "model = AutoModelForMaskedLM.from_config(config)\n",
    "assert model.get_input_embeddings().num_embeddings == len(tokenizer)\n",
    "\n",
    "model.to(device)\n",
    "\n",
    "# Define the trainer\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='pretrain_bert',\n",
    "    overwrite_output_dir=True,\n",
    "    do_train=True,\n",
    "    do_eval=True,\n",
    "    per_device_train_batch_size=256,\n",
    "    per_device_eval_batch_size=256,\n",
    "    evaluation_strategy='steps',\n",
    "    logging_steps=100,\n",
    "    eval_steps=None,  # None falls back to logging_steps, i.e. evaluate every 100 steps\n",
    "    prediction_loss_only=True,\n",
    "    learning_rate=1e-5,\n",
    "    weight_decay=0.01,\n",
    "    adam_epsilon=1e-8,\n",
    "    max_grad_norm=1.0,\n",
    "    num_train_epochs=10,\n",
    "    save_steps=200,\n",
    "    push_to_hub=False,\n",
    ")\n",
    "\n",
    "# Selects 15% of the tokens of every batch as masked-LM prediction targets.\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    "    data_collator=data_collator,\n",
    "    # Passing the tokenizer makes the Trainer save it next to every checkpoint, so the exact\n",
    "    # vocabulary used for training can be reloaded together with the weights.\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# Original note: classification, top-6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11bef5af-f207-41c4-bebe-6ab3ed2744c7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-11T23:16:58.546819Z",
     "iopub.status.busy": "2022-07-11T23:16:58.546646Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/.local/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n",
      "***** Running training *****\n",
      "  Num examples = 168909\n",
      "  Num Epochs = 10\n",
      "  Instantaneous batch size per device = 256\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 256\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 6600\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='417' max='6600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 417/6600 06:48 < 1:41:21, 1.02 it/s, Epoch 0.63/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>8.059800</td>\n",
       "      <td>6.861248</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>6.209700</td>\n",
       "      <td>5.313494</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>4.936200</td>\n",
       "      <td>4.277523</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>4.064000</td>\n",
       "      <td>3.528285</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 2000\n",
      "  Batch size = 256\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 2000\n",
      "  Batch size = 256\n",
      "Saving model checkpoint to pretrain_bert/checkpoint-200\n",
      "Configuration saved in pretrain_bert/checkpoint-200/config.json\n",
      "Model weights saved in pretrain_bert/checkpoint-200/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 2000\n",
      "  Batch size = 256\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 2000\n",
      "  Batch size = 256\n",
      "Saving model checkpoint to pretrain_bert/checkpoint-400\n",
      "Configuration saved in pretrain_bert/checkpoint-400/config.json\n",
      "Model weights saved in pretrain_bert/checkpoint-400/pytorch_model.bin\n"
     ]
    }
   ],
   "source": [
    "# Run masked-LM pre-training, then write the final model to training_args.output_dir\n",
    "# ('pretrain_bert'). Intermediate checkpoints are written every save_steps steps.\n",
    "trainer.train()\n",
    "trainer.save_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "cd624724-5331-4367-8461-e05866896141",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T02:06:27.150863Z",
     "iopub.status.busy": "2022-07-12T02:06:27.150264Z",
     "iopub.status.idle": "2022-07-12T02:06:28.588682Z",
     "shell.execute_reply": "2022-07-12T02:06:28.588260Z",
     "shell.execute_reply.started": "2022-07-12T02:06:27.150813Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading configuration file ./pretrain_bert/checkpoint-6600/config.json\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"./pretrain_bert/checkpoint-6600/\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"torch_dtype\": \"float32\",\n",
      "  \"transformers_version\": \"4.18.0\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "loading weights file ./pretrain_bert/checkpoint-6600/pytorch_model.bin\n",
      "All model checkpoint weights were used when initializing BertForMaskedLM.\n",
      "\n",
      "All the weights of BertForMaskedLM were initialized from the model checkpoint at ./pretrain_bert/checkpoint-6600/.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use BertForMaskedLM for predictions without further training.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Embedding(43928, 768)"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "from transformers import AutoConfig, AutoModelForMaskedLM\n",
    "from transformers import Trainer, TrainingArguments\n",
    "from transformers import DataCollatorForLanguageModeling\n",
    "\n",
    "import warnings\n",
    "\n",
    "# The checkpoint directory only holds tokenizer files if the Trainer was created with\n",
    "# `tokenizer=...`; otherwise the in-memory `tokenizer` built in the cells above is reused.\n",
    "# tokenizer = AutoTokenizer.from_pretrained('./pretrain_bert/checkpoint-6600/')\n",
    "model = AutoModelForMaskedLM.from_pretrained('./pretrain_bert/checkpoint-6600/')\n",
    "\n",
    "# resize_token_embeddings appends randomly initialised (untrained) rows when the checkpoint was\n",
    "# trained with a smaller vocabulary than the tokenizer now has. Predictions that involve those\n",
    "# tokens are then meaningless, so make the mismatch visible instead of resizing silently.\n",
    "trained_vocab_size = model.get_input_embeddings().num_embeddings\n",
    "if trained_vocab_size != len(tokenizer):\n",
    "    warnings.warn(\n",
    "        'checkpoint vocabulary (%d) differs from tokenizer vocabulary (%d); '\n",
    "        'resized embedding rows are untrained' % (trained_vocab_size, len(tokenizer))\n",
    "    )\n",
    "model.resize_token_embeddings(len(tokenizer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "7c5eef0a-f2b3-49d6-aac0-95476f4735b2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T02:08:02.148130Z",
     "iopub.status.busy": "2022-07-12T02:08:02.147522Z",
     "iopub.status.idle": "2022-07-12T02:08:02.155974Z",
     "shell.execute_reply": "2022-07-12T02:08:02.155298Z",
     "shell.execute_reply.started": "2022-07-12T02:08:02.148079Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Minimal custom dataset for inference: the last item of every sequence is replaced by [MASK].\n",
    "class TestDataset(Dataset):\n",
    "    \"\"\"Build fixed-length model inputs whose final item is masked.\n",
    "\n",
    "    Args:\n",
    "        data_list: Iterable (list, pandas Series, ...) whose elements are lists of string\n",
    "            tokens. The stored sequences are never modified.\n",
    "        tokenizer: Object exposing a Hugging Face style ``encode_plus`` method.\n",
    "        max_seq_len: Length every example is padded / truncated to.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_list, tokenizer, max_seq_len):\n",
    "        super(TestDataset, self).__init__()\n",
    "        # A plain list keeps __getitem__ positional even when a pandas Series indexed by\n",
    "        # user id is passed in.\n",
    "        self.data_list = list(data_list)\n",
    "        self.len = len(self.data_list)\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_seq_len = max_seq_len\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``input_ids``, ``token_type_ids`` and ``attention_mask`` as long tensors.\"\"\"\n",
    "        # Reserve two positions for the special tokens a BERT-style tokenizer adds and keep\n",
    "        # the most recent items that fit (14 items for max_seq_len=16).\n",
    "        window = max(self.max_seq_len - 2, 1)\n",
    "        recent = list(self.data_list[index][-window:])\n",
    "        # Work on a copy: writing '[MASK]' into the stored list would destroy the true last\n",
    "        # item for every later access. An empty sequence yields a lone '[MASK]'.\n",
    "        tokens = recent[:-1] + ['[MASK]']\n",
    "        # NOTE(review): if an item id is split into several word pieces, right-side truncation\n",
    "        # can cut the trailing [MASK] off -- confirm that every id is a single token.\n",
    "        encoded = self.tokenizer.encode_plus(\n",
    "            ' '.join(tokens),\n",
    "            return_token_type_ids=True,\n",
    "            return_attention_mask=True,\n",
    "            padding='max_length',\n",
    "            truncation=True,\n",
    "            max_length=self.max_seq_len,\n",
    "        )\n",
    "        fields = ('input_ids', 'token_type_ids', 'attention_mask')\n",
    "        return {name: torch.tensor(encoded[name], dtype=torch.long) for name in fields}\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "f88e2f90-b73b-4d98-8e33-998e355b8c80",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T02:08:02.209022Z",
     "iopub.status.busy": "2022-07-12T02:08:02.208478Z",
     "iopub.status.idle": "2022-07-12T02:08:03.294554Z",
     "shell.execute_reply": "2022-07-12T02:08:03.293071Z",
     "shell.execute_reply.started": "2022-07-12T02:08:02.208974Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Rebuild the held-out split with the last item of each sequence masked, for inference.\n",
    "val_dataset = TestDataset(data_list=seq_vid_list[-2000:], tokenizer=tokenizer, max_seq_len=16)\n",
    "\n",
    "# Keep the original order so predictions line up with seq_vid_list[-2000:].\n",
    "val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "dd652a96-e08f-40fa-8b73-9b5cb3f56dcb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T02:08:04.149544Z",
     "iopub.status.busy": "2022-07-12T02:08:04.148954Z",
     "iopub.status.idle": "2022-07-12T02:08:04.161797Z",
     "shell.execute_reply": "2022-07-12T02:08:04.161156Z",
     "shell.execute_reply.started": "2022-07-12T02:08:04.149494Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Grab the first batch only, for a quick sanity check of the model.\n",
    "data = next(iter(val_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "8efed761-c90c-477e-bf55-5b356cabe35a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T02:08:04.353534Z",
     "iopub.status.busy": "2022-07-12T02:08:04.352964Z",
     "iopub.status.idle": "2022-07-12T02:08:04.533427Z",
     "shell.execute_reply": "2022-07-12T02:08:04.532953Z",
     "shell.execute_reply.started": "2022-07-12T02:08:04.353472Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Inference only: disable autograd so no graph is built and no activations are retained.\n",
    "with torch.no_grad():\n",
    "    pred = model(input_ids=data['input_ids'], attention_mask=data['attention_mask']).logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "e3764d16-263c-434d-b26a-552b0ab9ac1d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T02:08:06.150623Z",
     "iopub.status.busy": "2022-07-12T02:08:06.149719Z",
     "iopub.status.idle": "2022-07-12T02:08:06.157860Z",
     "shell.execute_reply": "2022-07-12T02:08:06.156963Z",
     "shell.execute_reply.started": "2022-07-12T02:08:06.150542Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Convert the logits to a NumPy array. The guard keeps the cell re-runnable once `pred` has\n",
    "# already been converted; detach() replaces the legacy `.data` attribute.\n",
    "if torch.is_tensor(pred):\n",
    "    pred = pred.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "6109f828-9c86-4e54-9b7f-a240bd23b7d0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T02:09:20.152438Z",
     "iopub.status.busy": "2022-07-12T02:09:20.151851Z",
     "iopub.status.idle": "2022-07-12T02:09:20.157443Z",
     "shell.execute_reply": "2022-07-12T02:09:20.156323Z",
     "shell.execute_reply.started": "2022-07-12T02:09:20.152388Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Row of the batch to inspect.\n",
    "idx = 0\n",
    "input_ids = data['input_ids'][idx].numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "d2f675be-a1b2-47c1-8824-be2fc678045b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T02:09:58.438802Z",
     "iopub.status.busy": "2022-07-12T02:09:58.438206Z",
     "iopub.status.idle": "2022-07-12T02:09:58.446763Z",
     "shell.execute_reply": "2022-07-12T02:09:58.445744Z",
     "shell.execute_reply.started": "2022-07-12T02:09:58.438753Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2683"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "7893c9f4-c793-49f3-8498-67c26e5fec2c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-07-12T02:10:18.413878Z",
     "iopub.status.busy": "2022-07-12T02:10:18.413275Z",
     "iopub.status.idle": "2022-07-12T02:10:18.422236Z",
     "shell.execute_reply": "2022-07-12T02:10:18.421171Z",
     "shell.execute_reply.started": "2022-07-12T02:10:18.413826Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'##9'"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Locate the [MASK] slot through the tokenizer instead of the hard-coded BERT id 103, and\n",
    "# read the logits of the same batch row (`idx`) that `input_ids` was taken from.\n",
    "mask_positions = np.where(input_ids == tokenizer.mask_token_id)[0]\n",
    "assert mask_positions.size > 0, 'no [MASK] token in input_ids (was it truncated away?)'\n",
    "tokenizer.decode(int(pred[idx, mask_positions[0]].argmax()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2497c56-91d9-4b03-b5e4-a50fcdad48b8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
